import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import Linear
from torch.distributions import Categorical
from itertools import accumulate
from model.gnn import GNN

class REINFORCE(nn.Module):
    """REINFORCE policy that scores candidate (machine, operation) pairs.

    Node embeddings are produced by ``GNN(args)``. For each candidate, the
    machine embedding and the operation embedding are concatenated and fed
    through a small MLP that outputs one logit. A categorical distribution
    over those logits selects the next candidate.
    """

    def __init__(self, args):
        """Build the GNN encoder and the candidate-scoring MLP.

        Args:
            args: config object. This class reads ``policy_num_layers``,
                ``hidden_dim``, ``device`` and ``entropy_coef``; ``GNN`` may
                read additional fields.
        """
        super().__init__()
        self.args = args
        self.num_layers = args.policy_num_layers
        self.hidden_dim = args.hidden_dim
        self.gnn = GNN(args)
        # MLP: (2 * hidden) -> hidden -> ... -> hidden -> 1 logit.
        # At least two layers are always built (first and last), even when
        # policy_num_layers < 2, because range() of a negative count is empty.
        self.layers = torch.nn.ModuleList()
        self.layers.append(nn.Linear(self.hidden_dim * 2, self.hidden_dim))
        for _ in range(self.num_layers - 2):
            self.layers.append(nn.Linear(self.hidden_dim, self.hidden_dim))
        self.layers.append(nn.Linear(self.hidden_dim, 1))
        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-initialize the weight of every ``Linear`` submodule.

        NOTE(review): ``Linear`` here is ``torch_geometric.nn.Linear`` (see the
        file imports), not ``torch.nn.Linear``. The policy MLP in
        ``self.layers`` is built from ``nn.Linear``, so it keeps PyTorch's
        default initialization; only PyG ``Linear`` modules (e.g. inside
        ``self.gnn``, if it uses any) are re-initialized. Confirm whether
        ``nn.Linear`` was meant to be included before changing this, since it
        would alter training behavior.
        """
        for m in self.modules():
            if isinstance(m, Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, avai_op, data, greedy=False, T=1.0):
        """Score the available operations and choose one.

        Args:
            avai_op: iterable of dicts, each with ``'m_id'`` (row index into
                ``data['m_x']``) and ``'node_id'`` (looked up with ``.index``
                in ``data['unfinish_op']``, whose positions index rows of
                ``data['op_x']``).
            data: graph input passed through ``self.gnn``. The encoded result
                must provide ``'m_x'``, ``'op_x'`` and ``'unfinish_op'``.
                Rows of ``m_x``/``op_x`` are presumably 1-D embeddings of size
                ``hidden_dim``; TODO confirm against ``GNN``.
            greedy: if True, take the argmax logit; otherwise sample.
            T: softmax temperature. Logits are divided by ``T`` before the
                softmax, so ``T`` is expected to be positive. ``T == 1.0``
                leaves the logits unchanged.

        Returns:
            ``(index, prob, log_prob, entropy)``. Here ``index`` is the
            position in ``avai_op`` of the chosen candidate (``int``),
            ``prob`` is its probability (``float``), and ``log_prob`` and
            ``entropy`` are tensors from the categorical distribution.

        Raises:
            ValueError: if ``avai_op`` is empty.
        """
        data = self.gnn(data)
        m_x = data['m_x']
        op_x = data['op_x']
        unfinish_op = data['unfinish_op']

        # One row per candidate: [machine embedding || operation embedding].
        # Collect, then stack once, instead of re-concatenating in the loop.
        pairs = [
            torch.cat((m_x[op_info['m_id']],
                       op_x[unfinish_op.index(op_info['node_id'])]), dim=0)
            for op_info in avai_op
        ]
        if not pairs:
            raise ValueError("avai_op must contain at least one candidate operation")
        score = torch.stack(pairs, dim=0).to(self.args.device)

        # Hidden layers with ReLU, then the final 1-unit projection.
        for layer in self.layers[:-1]:
            score = F.relu(layer(score))
        # (num_candidates, 1) -> (num_candidates,). reshape keeps a 1-D vector
        # even for a single candidate; squeeze() would yield a 0-d tensor,
        # which Categorical rejects.
        score = self.layers[-1](score).reshape(-1)

        probs = F.softmax(score / T, dim=0)
        dist = Categorical(probs)
        idx = torch.argmax(score) if greedy else dist.sample()
        return idx.item(), probs[idx].item(), dist.log_prob(idx), dist.entropy()

    def calculate_loss(self, device, log_probs, entropies, baselines, rewards):
        """Compute the REINFORCE loss with a relative baseline and entropy bonus.

        Returns are undiscounted reward-to-go: ``R_t = sum(rewards[t:])``. The
        per-step advantage is ``-(R_t - b_t) / b_t``, or ``-R_t`` when
        ``b_t == 0``. Each step contributes
        ``-log_prob * advantage - entropy_coef * entropy``. Minimizing this
        lowers the log-probability of actions whose return exceeds the
        baseline, i.e. returns are treated as costs (lower is better).

        Args:
            device: device for the returns tensor.
            log_probs: per-step log-probabilities (as returned by ``forward``).
            entropies: per-step distribution entropies.
            baselines: per-step baseline values.
            rewards: per-step rewards (a sequence supporting ``[::-1]``).

        Returns:
            ``(loss, policy_loss, entropy_loss)``. ``loss`` is the mean
            per-step loss. ``policy_loss`` is the mean of
            ``log_prob * advantage`` and ``entropy_loss`` is the mean entropy.
            Both are for logging and still carry the autograd graph when the
            inputs are tensors.

            NOTE(review): ``policy_loss`` has the opposite sign of the policy
            term inside ``loss``. Confirm this is intended for logging.
        """
        loss = []
        # Reward-to-go: reverse, running sum, reverse back.
        returns = torch.FloatTensor(list(accumulate(rewards[::-1]))[::-1]).to(device)
        policy_loss = 0.0
        entropy_loss = 0.0

        for log_prob, entropy, baseline, R in zip(log_probs, entropies, baselines, returns):
            if baseline == 0:
                # Relative advantage is undefined for a zero baseline; fall
                # back to the (negated) raw return.
                advantage = R * -1
            else:
                advantage = ((R - baseline) / baseline) * -1

            loss.append(-log_prob * advantage - self.args.entropy_coef * entropy)
            policy_loss += log_prob * advantage
            entropy_loss += entropy

        return torch.stack(loss).mean(), policy_loss / len(log_probs), entropy_loss / len(log_probs)
         